{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.model_selection import KFold\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_string(string):\n",
    "    string = re.sub('[^A-Za-z0-9\\-\\/ ]+', ' ', string).split()\n",
    "    return [y.strip() for y in string]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('pos-data-v3.json','r') as fopen:\n",
    "    dataset = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['뭘봐', '뭘봐', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['ひ', 'ひ', 'PROPN']\n",
      "list index out of range ['ヒ', 'ヒ', 'PROPN']\n",
      "list index out of range ['形聲', '形聲', 'NOUN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['汉', '汉', 'PROPN']\n",
      "list index out of range ['东', '东', 'PROPN']\n",
      "list index out of range ['王', '王', 'PROPN']\n",
      "list index out of range ['（', '（', 'PROPN']\n",
      "list index out of range ['伊', '伊', 'PROPN']\n",
      "list index out of range ['）', '）', 'PROPN']\n",
      "list index out of range ['ȝ', 'ȝ', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range [\"'\", '_', 'PROPN']\n",
      "list index out of range ['碁', '碁', 'NOUN']\n",
      "list index out of range ['囲碁', '囲碁', 'NOUN']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['会', '会', 'PROPN']\n",
      "list index out of range ['蔡武侯', '蔡武侯', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['+', '+', 'SYM']\n",
      "list index out of range ['+', '+', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['*', '*', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['+', '+', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['~', '~', 'ADJ']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['ОАО', 'оао', 'PROPN']\n",
      "list index out of range ['Газпром', 'газпром', 'PROPN']\n",
      "list index out of range [\"''\", '_', 'PROPN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['佐々木', '佐々木', 'PROPN']\n",
      "list index out of range ['功', '功', 'PROPN']\n",
      "list index out of range ['ɳ', 'ɳ', 'PROPN']\n",
      "list index out of range ['ʂ', 'ʂ', 'PROPN']\n",
      "list index out of range ['ʈ', 'ʈ', 'PROPN']\n",
      "list index out of range ['ɖ', 'ɖ', 'PROPN']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range [\"'\", '_', 'NOUN']\n",
      "list index out of range [\"'\", '_', 'PROPN']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['鉄の森', '鉄の森', 'PROPN']\n",
      "list index out of range ['アイゼンヴァルト', 'アイゼンヴァルト', 'PROPN']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['+', '+', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['壽', '壽', 'PROPN']\n",
      "list index out of range ['陽', '陽', 'PROPN']\n",
      "list index out of range ['裴叔業', '裴叔業', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['+', '+', 'SYM']\n",
      "list index out of range ['فيروز', 'فيروز', 'PROPN']\n",
      "list index out of range ['الديلامي', 'الديلامي', 'PROPN']\n",
      "list index out of range ['נביאים', 'נביאים', 'NOUN']\n",
      "list index out of range ['ראשונים', 'ראשונים', 'NOUN']\n",
      "list index out of range ['נביאים', 'נביאים', 'NOUN']\n",
      "list index out of range ['جامع', 'جامع', 'PROPN']\n",
      "list index out of range ['الرئيس', 'الرئيس', 'PROPN']\n",
      "list index out of range ['الصالح', 'الصالح', 'PROPN']\n",
      "list index out of range ['—', '—', 'NOUN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['ä', 'ä', 'PROPN']\n",
      "list index out of range ['ʈ', 'ʈ', 'PROPN']\n",
      "list index out of range ['ʈʰ', 'ʈʰ', 'PROPN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range [\"''\", '_', 'PROPN']\n",
      "list index out of range ['€', '€', 'SYM']\n",
      "list index out of range ['\\u200e', '\\u200e', 'VERB']\n",
      "list index out of range ['\\u200e', '\\u200e', 'NOUN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['$', '$', 'SYM']\n",
      "list index out of range ['Α', 'α', 'PROPN']\n",
      "list index out of range ['А', 'а', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['$', '$', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['おる', 'おる', 'PROPN']\n",
      "list index out of range ['いる', 'いる', 'PROPN']\n",
      "list index out of range ['じゃ', 'じゃ', 'PROPN']\n",
      "list index out of range ['や', 'や', 'PROPN']\n",
      "list index out of range ['だ', 'だ', 'PROPN']\n",
      "list index out of range ['～', '～', 'PROPN']\n",
      "list index out of range ['ん', 'ん', 'PROPN']\n",
      "list index out of range ['わから', 'わから', 'PROPN']\n",
      "list index out of range ['ん', 'ん', 'PROPN']\n",
      "list index out of range ['～', '～', 'PROPN']\n",
      "list index out of range ['ない', 'ない', 'PROPN']\n",
      "list index out of range ['わからない', 'わからない', 'PROPN']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['×', '×', 'NOUN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['–', '–', 'NUM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['続', '続', 'PROPN']\n",
      "list index out of range ['・', '・', 'PROPN']\n",
      "list index out of range ['三', '三', 'PROPN']\n",
      "list index out of range ['丁目', '丁目', 'PROPN']\n",
      "list index out of range ['の', 'の', 'PROPN']\n",
      "list index out of range ['夕日', '夕日', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['Ἑλλάδος', 'ἑλλάδος', 'PROPN']\n",
      "list index out of range ['παίδευσις', 'παίδευσις', 'PROPN']\n",
      "list index out of range ['€', '€', 'SYM']\n",
      "list index out of range ['ê', 'ê', 'PROPN']\n",
      "list index out of range ['é', 'é', 'PROPN']\n",
      "list index out of range ['è', 'è', 'PROPN']\n",
      "list index out of range ['ë', 'ë', 'PROPN']\n",
      "list index out of range ['ē', 'ē', 'PROPN']\n",
      "list index out of range ['ĕ', 'ĕ', 'PROPN']\n",
      "list index out of range ['ě', 'ě', 'PROPN']\n",
      "list index out of range ['ẽ', 'ẽ', 'PROPN']\n",
      "list index out of range ['ė', 'ė', 'PROPN']\n",
      "list index out of range ['ę', 'ę', 'PROPN']\n",
      "list index out of range ['ẻ', 'ẻ', 'PROPN']\n",
      "list index out of range ['ö', 'ö', 'PROPN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['.', '.', 'PROPN']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['بندر', 'بندر', 'PROPN']\n",
      "list index out of range ['عباس', 'عباس', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['+', '+', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['$', '$', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['珠海', '珠海', 'PROPN']\n",
      "list index out of range [\"''\", '_', 'NOUN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['서종', '서종', 'PROPN']\n",
      "list index out of range ['제', '제', 'PROPN']\n",
      "list index out of range ['±', '±', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['仙', '仙', 'PROPN']\n",
      "list index out of range ['仚', '仚', 'PROPN']\n",
      "list index out of range ['僊', '僊', 'PROPN']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['腹切り', '腹切り', 'PROPN']\n",
      "list index out of range ['§', '§', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['육식', '육식', 'PROPN']\n",
      "list index out of range ['동물', '동물', 'PROPN']\n",
      "list index out of range ['도망자', '도망자', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['الغرب', 'الغرب', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['سدوم', 'سدوم', 'PROPN']\n",
      "list index out of range ['عمورة', 'عمورة', 'PROPN']\n",
      "list index out of range ['高雄', '高雄', 'PROPN']\n",
      "list index out of range ['大眾', '大眾', 'PROPN']\n",
      "list index out of range ['捷', '捷', 'PROPN']\n",
      "list index out of range ['運', '運', 'PROPN']\n",
      "list index out of range ['系統', '系統', 'PROPN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['=', '=', 'SYM']\n",
      "list index out of range ['亳', '亳', 'PROPN']\n",
      "list index out of range ['°', '°', 'SYM']\n",
      "list index out of range ['المعقبات', 'المعقبات', 'PROPN']\n",
      "list index out of range ['%', '%', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['#', '#', 'SYM']\n",
      "list index out of range ['%', '%', 'SYM']\n"
     ]
    }
   ],
   "source": [
    "texts, labels = [], []\n",
    "for i in dataset:\n",
    "    try:\n",
    "        texts.append(process_string(i[0])[0].lower())\n",
    "        labels.append(i[-1])\n",
    "    except Exception as e:\n",
    "        print(e, i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "word2idx = {'PAD': 0,'NUM':1,'UNK':2}\n",
    "tag2idx = {'PAD': 0}\n",
    "char2idx = {'PAD': 0}\n",
    "word_idx = 3\n",
    "tag_idx = 1\n",
    "char_idx = 1\n",
    "\n",
    "def parse_XY(texts, labels):\n",
    "    global word2idx, tag2idx, char2idx, word_idx, tag_idx, char_idx\n",
    "    X, Y = [], []\n",
    "    for no, text in enumerate(texts):\n",
    "        text = text.lower()\n",
    "        tag = labels[no]\n",
    "        for c in text:\n",
    "            if c not in char2idx:\n",
    "                char2idx[c] = char_idx\n",
    "                char_idx += 1\n",
    "        if tag not in tag2idx:\n",
    "            tag2idx[tag] = tag_idx\n",
    "            tag_idx += 1\n",
    "        Y.append(tag2idx[tag])\n",
    "        if text not in word2idx:\n",
    "            word2idx[text] = word_idx\n",
    "            word_idx += 1\n",
    "        X.append(text)\n",
    "    return X, np.array(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, Y = parse_XY(texts, labels)\n",
    "idx2char = {idx: tag for tag, idx in char2idx.items()}\n",
    "idx2tag = {i: w for w, i in tag2idx.items()}\n",
    "onehot = np.zeros((Y.shape[0],len(tag2idx)))\n",
    "onehot[np.arange(Y.shape[0]),Y] = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic, UNK=0):\n",
    "    maxlen = max([len(i) for i in corpus])\n",
    "    X = np.zeros((len(corpus),maxlen))\n",
    "    for i in range(len(corpus)):\n",
    "        for no, k in enumerate(corpus[i][:maxlen][::-1]):\n",
    "            try:\n",
    "                X[i,-1 - no]=dic[k]\n",
    "            except Exception as e:\n",
    "                X[i,-1 - no]=UNK\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_seq = str_idx(X,char2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open('char-bidirectional-pos.json','w') as fopen:\n",
    "    fopen.write(json.dumps({'idx2tag':idx2tag, 'char2idx':char2idx,'tag2idx':tag2idx}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.cross_validation import train_test_split\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(X_seq, onehot, test_size=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "print(keras.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import Model, Input\n",
    "from keras.layers import LSTM, Embedding, Dense, Dropout, Bidirectional\n",
    "from keras.backend.tensorflow_backend import set_session\n",
    "set_session(tf.InteractiveSession())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_word = Input(shape=(None,))\n",
    "model = Embedding(input_dim=len(char2idx) + 1, output_dim=128, mask_zero=True)(input_word)\n",
    "model = Bidirectional(LSTM(units=128, return_sequences=False, recurrent_dropout=0.1))(model)\n",
    "out = Dense(len(tag2idx),activation='softmax')(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(input_word, out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer=\"adam\", loss='categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, None)              0         \n",
      "_________________________________________________________________\n",
      "embedding_1 (Embedding)      (None, None, 128)         5120      \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, 256)               263168    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 16)                4112      \n",
      "=================================================================\n",
      "Total params: 272,400\n",
      "Trainable params: 272,400\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"264pt\" viewBox=\"0.00 0.00 277.00 264.00\" width=\"277pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 260)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" points=\"-4,4 -4,-260 273,-260 273,4 -4,4\" stroke=\"none\"/>\n",
       "<!-- 140322986741096 -->\n",
       "<g class=\"node\" id=\"node1\"><title>140322986741096</title>\n",
       "<polygon fill=\"none\" points=\"72,-219.5 72,-255.5 197,-255.5 197,-219.5 72,-219.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-233.8\">input_1: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140322986741152 -->\n",
       "<g class=\"node\" id=\"node2\"><title>140322986741152</title>\n",
       "<polygon fill=\"none\" points=\"54,-146.5 54,-182.5 215,-182.5 215,-146.5 54,-146.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-160.8\">embedding_1: Embedding</text>\n",
       "</g>\n",
       "<!-- 140322986741096&#45;&gt;140322986741152 -->\n",
       "<g class=\"edge\" id=\"edge1\"><title>140322986741096-&gt;140322986741152</title>\n",
       "<path d=\"M134.5,-219.313C134.5,-211.289 134.5,-201.547 134.5,-192.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-192.529 134.5,-182.529 131,-192.529 138,-192.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140322919314152 -->\n",
       "<g class=\"node\" id=\"node3\"><title>140322919314152</title>\n",
       "<polygon fill=\"none\" points=\"0,-73.5 0,-109.5 269,-109.5 269,-73.5 0,-73.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-87.8\">bidirectional_1(lstm_1): Bidirectional(LSTM)</text>\n",
       "</g>\n",
       "<!-- 140322986741152&#45;&gt;140322919314152 -->\n",
       "<g class=\"edge\" id=\"edge2\"><title>140322986741152-&gt;140322919314152</title>\n",
       "<path d=\"M134.5,-146.313C134.5,-138.289 134.5,-128.547 134.5,-119.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-119.529 134.5,-109.529 131,-119.529 138,-119.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140323054465320 -->\n",
       "<g class=\"node\" id=\"node4\"><title>140323054465320</title>\n",
       "<polygon fill=\"none\" points=\"83.5,-0.5 83.5,-36.5 185.5,-36.5 185.5,-0.5 83.5,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"134.5\" y=\"-14.8\">dense_1: Dense</text>\n",
       "</g>\n",
       "<!-- 140322919314152&#45;&gt;140323054465320 -->\n",
       "<g class=\"edge\" id=\"edge3\"><title>140322919314152-&gt;140323054465320</title>\n",
       "<path d=\"M134.5,-73.3129C134.5,-65.2895 134.5,-55.5475 134.5,-46.5691\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"138,-46.5288 134.5,-36.5288 131,-46.5289 138,-46.5288\" stroke=\"black\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import SVG\n",
    "from keras.utils.vis_utils import model_to_dot\n",
    "\n",
    "SVG(model_to_dot(model).create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 83767 samples, validate on 9308 samples\n",
      "Epoch 1/2\n",
      "83767/83767 [==============================] - 147s 2ms/step - loss: 0.9461 - acc: 0.6823 - val_loss: 0.6229 - val_acc: 0.7899\n",
      "Epoch 2/2\n",
      "83767/83767 [==============================] - 146s 2ms/step - loss: 0.5629 - acc: 0.8105 - val_loss: 0.5081 - val_acc: 0.8290\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(train_X, train_Y, batch_size=32, epochs=2,\n",
    "                    validation_split=0.1, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted = model.predict(test_X)\n",
    "labels = [i[0] for i in sorted(tag2idx.items(), key=lambda t: t[1])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "       NOUN       0.74      0.82      0.78      2709\n",
      "        ADP       0.95      0.96      0.95      1202\n",
      "        NUM       0.92      0.92      0.92       448\n",
      "       VERB       0.92      0.90      0.91      1183\n",
      "      PROPN       0.76      0.73      0.74      2233\n",
      "      CCONJ       0.98      0.94      0.96       375\n",
      "        ADJ       0.71      0.47      0.56       453\n",
      "       PRON       0.93      0.94      0.94       528\n",
      "        ADV       0.84      0.75      0.79       470\n",
      "        AUX       0.99      1.00      1.00       124\n",
      "        DET       0.90      0.92      0.91       397\n",
      "      SCONJ       0.64      0.76      0.69       143\n",
      "       PART       0.77      0.89      0.82        55\n",
      "        SYM       1.00      0.85      0.92        20\n",
      "          X       0.00      0.00      0.00         2\n",
      "\n",
      "avg / total       0.83      0.83      0.82     10342\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
      "  'precision', 'predicted', average, warn_for)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "print(classification_report(np.argmax(test_Y,1), np.argmax(predicted,1), target_names=labels[1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_weights('char-bidirectional-pos.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
